import abc
import plistlib
import socket
import time
from dataclasses import dataclass
from typing import Optional

from construct import (
    Const,
    CString,
    Enum,
    FixedSized,
    GreedyBytes,
    Int16ul,
    Int32ul,
    Padding,
    Prefixed,
    StreamError,
    Struct,
    Switch,
    this,
)

from pymobiledevice3.exceptions import (
    BadCommandError,
    BadDevError,
    ConnectionFailedError,
    ConnectionFailedToUsbmuxdError,
    MuxException,
    MuxVersionError,
    NotPairedError,
)
from pymobiledevice3.osu.os_utils import get_os_utils

# Protocol flavour carried in the "version" field of every packet header,
# encoded as an unsigned 32-bit little-endian integer (Int32ul).
# MuxConnection.create() inspects this field in the daemon's reply to decide
# between BinaryMuxConnection and PlistMuxConnection.
usbmuxd_version = Enum(
    Int32ul,
    BINARY=0,
    PLIST=1,
)

# Result codes carried in the payload of RESULT packets (see usbmuxd_response).
# Every non-OK value listed here is mapped to an exception class by
# MuxConnection._raise_mux_exception().
# NOTE: value 5 has no symbolic name in this table.
usbmuxd_result = Enum(
    Int32ul,
    OK=0,
    BADCOMMAND=1,
    BADDEV=2,
    CONNREFUSED=3,
    NOSUCHSERVICE=4,
    BADVERSION=6,
)

# Message types carried in the "message" field of the packet header.
# The request/response Switch tables below select the payload layout from it.
# NOTE: value 7 has no symbolic name in this table.
usbmuxd_msgtype = Enum(
    Int32ul,
    RESULT=1,
    CONNECT=2,
    LISTEN=3,
    ADD=4,
    REMOVE=5,
    PAIRED=6,
    PLIST=8,
)

# Fixed header shared by requests and responses: three little-endian uint32
# fields. The packet length is NOT part of this struct; it is added by the
# Prefixed(...) wrappers of usbmuxd_request / usbmuxd_response.
usbmuxd_header = Struct(
    "version" / usbmuxd_version,  # protocol version
    "message" / usbmuxd_msgtype,  # message type
    "tag" / Int32ul,  # responses to this query will echo back this tag
)

# Outgoing packet: a uint32 length prefix followed by header + payload.
# includelength=True means the length value also counts the 4 bytes of the
# prefix itself (per construct's Prefixed documentation).
# The payload layout depends on header.message:
#   CONNECT -> device id, port and a reserved zero field
#   PLIST   -> raw bytes (the serialized plist built by the caller)
# Other message types (e.g. LISTEN, as sent by BinaryMuxConnection) have no
# entry in the table; presumably construct's Switch default (Pass) applies and
# no payload is emitted - confirm against the construct documentation.
usbmuxd_request = Prefixed(
    Int32ul,
    Struct(
        "header" / usbmuxd_header,
        "data"
        / Switch(
            this.header.message,
            {
                usbmuxd_msgtype.CONNECT: Struct(
                    "device_id" / Int32ul,
                    "port" / Int16ul,  # TCP port number
                    "reserved" / Const(0, Int16ul),
                ),
                usbmuxd_msgtype.PLIST: GreedyBytes,
            },
        ),
    ),
    includelength=True,
)

# Payload of an ADD packet in the binary protocol (see usbmuxd_response).
# BinaryMuxConnection._receive_device_state_update() reads device_id and
# serial_number from it.
usbmuxd_device_record = Struct(
    "device_id" / Int32ul,
    "product_id" / Int16ul,
    # NUL-terminated ASCII string stored inside a fixed 256-byte field
    "serial_number" / FixedSized(256, CString("ascii")),
    # two padding bytes before the next 32-bit field
    Padding(2),
    "location" / Int32ul,
)

# Incoming packet: a uint32 length prefix followed by header + payload.
# includelength=True means the length value also counts the 4 bytes of the
# prefix itself (per construct's Prefixed documentation).
# The payload layout depends on header.message:
#   RESULT -> a single usbmuxd_result code
#   ADD    -> a usbmuxd_device_record
#   REMOVE -> the id of the device that went away
#   PLIST  -> raw bytes (decoded with plistlib by PlistMuxConnection._receive)
usbmuxd_response = Prefixed(
    Int32ul,
    Struct(
        "header" / usbmuxd_header,
        "data"
        / Switch(
            this.header.message,
            {
                usbmuxd_msgtype.RESULT: Struct(
                    "result" / usbmuxd_result,
                ),
                usbmuxd_msgtype.ADD: usbmuxd_device_record,
                usbmuxd_msgtype.REMOVE: Struct(
                    "device_id" / Int32ul,
                ),
                usbmuxd_msgtype.PLIST: GreedyBytes,
            },
        ),
    ),
    includelength=True,
)


@dataclass
class MuxDevice:
    """A single device entry as reported by usbmuxd."""

    # identifier used as the target of "Connect" requests
    devid: int
    # serial number as reported by usbmuxd (may contain "-")
    serial: str
    # connection type string; "USB" and "Network" are the values tested below
    connection_type: str

    def connect(self, port: int, usbmux_address: Optional[str] = None) -> socket.socket:
        """
        Open a fresh mux connection and relay it to `port` on this device.

        :param port: port on the device to connect to
        :param usbmux_address: optional usbmuxd address forwarded to create_mux()
        :return: the raw python socket of the established relay
        """
        connection = create_mux(usbmux_address=usbmux_address)
        try:
            relay = connection.connect(self, port)
        except BaseException:
            # whatever went wrong (including interrupts), do not leave the mux socket open
            connection.close()
            raise
        return relay

    @property
    def is_usb(self) -> bool:
        """True when the device is attached over USB."""
        kind = self.connection_type
        return kind == "USB"

    @property
    def is_network(self) -> bool:
        """True when the device is reachable over the network."""
        kind = self.connection_type
        return kind == "Network"

    def matches_udid(self, udid: str) -> bool:
        """Compare this device's serial with `udid`, ignoring any "-" characters on both sides."""
        own_serial = self.serial.replace("-", "")
        wanted = udid.replace("-", "")
        return own_serial == wanted


class SafeStreamSocket:
    """
    wrapper to native python socket object to be used with construct as a stream

    Exposes `read`/`write`/`tell` on top of a connected stream socket. `recv`
    keeps reading until exactly the requested number of bytes has arrived.
    """

    def __init__(self, address, family):
        """
        Create a stream socket of the given family and connect it to `address`.

        :param address: address accepted by `socket.connect()` for `family`
        :param family: socket address family (e.g. socket.AF_INET, socket.AF_UNIX)
        :raises OSError: when the connection cannot be established
        """
        # number of bytes sent + received so far, reported by tell()
        self._offset = 0
        self.sock = socket.socket(family, socket.SOCK_STREAM)
        try:
            self.sock.connect(address)
        except BaseException:
            # the caller never gets a reference to this object on failure, so
            # release the descriptor here instead of leaking it
            self.sock.close()
            raise

    def send(self, msg: bytes) -> int:
        """Send all of `msg` and return its length."""
        self._offset += len(msg)
        self.sock.sendall(msg)
        return len(msg)

    def recv(self, size: int) -> bytes:
        """
        Receive exactly `size` bytes.

        :raises MuxException: when the peer closes the connection before `size` bytes arrived
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = self.sock.recv(remaining)
            if not chunk:
                # an empty read on a stream socket means the peer closed it
                raise MuxException("socket connection broken")
            self._offset += len(chunk)
            chunks.append(chunk)
            remaining -= len(chunk)
        # join once instead of re-copying the buffer on every partial read
        return b"".join(chunks)

    def close(self) -> None:
        """Close the underlying socket."""
        self.sock.close()

    def settimeout(self, interval: float) -> None:
        """Forward to `socket.settimeout()`."""
        self.sock.settimeout(interval)

    def setblocking(self, blocking: bool) -> None:
        """Forward to `socket.setblocking()`."""
        self.sock.setblocking(blocking)

    def tell(self) -> int:
        """Return the total number of bytes sent and received through this wrapper."""
        return self._offset

    # stream-style aliases used by construct
    read = recv
    write = send


class MuxConnection:
    """
    Base class for a control connection to usbmuxd.

    Use `create()` to obtain the concrete subclass matching the protocol
    version spoken by the daemon. Instances can be used as context managers;
    leaving the `with` block closes the socket.

    NOTE: `_connect` and `get_device_list` are decorated with
    `abc.abstractmethod`, but this class does not derive from `abc.ABC`, so the
    decorator is informational only and instantiation is not prevented.
    """

    # used on Windows
    ITUNES_HOST = ("127.0.0.1", 27015)

    # used for macOS and Linux
    USBMUXD_PIPE = "/var/run/usbmuxd"

    @staticmethod
    def create_usbmux_socket(usbmux_address: Optional[str] = None) -> SafeStreamSocket:
        """
        Open a socket to usbmuxd.

        :param usbmux_address: "host:port" for a TCP address, any other string for a
            unix domain socket path, or None to use the address from get_os_utils()
        :raises ConnectionFailedToUsbmuxdError: when the connection is refused
        """
        try:
            if usbmux_address is not None:
                if ":" in usbmux_address:
                    # assume tcp address; split on the last colon only so a stray
                    # colon in the host part does not break tuple unpacking
                    hostname, port = usbmux_address.rsplit(":", 1)
                    address = (hostname, int(port))
                    family = socket.AF_INET
                else:
                    # assume unix domain address
                    address = usbmux_address
                    family = socket.AF_UNIX
            else:
                address, family = get_os_utils().usbmux_address
            return SafeStreamSocket(address, family)
        except ConnectionRefusedError as e:
            raise ConnectionFailedToUsbmuxdError() from e

    @staticmethod
    def create(usbmux_address: Optional[str] = None):
        """
        Probe the protocol version spoken by usbmuxd and return a matching connection.

        A plist "ReadBUID" request is sent on a throw-away socket; the version
        field of the reply header selects BinaryMuxConnection or
        PlistMuxConnection, which is then given a fresh socket.

        :raises MuxVersionError: when the reply carries an unsupported version
        """
        # first attempt to connect with possibly the wrong version header (plist protocol)
        sock = MuxConnection.create_usbmux_socket(usbmux_address=usbmux_address)

        try:
            message = usbmuxd_request.build({
                "header": {"version": usbmuxd_version.PLIST, "message": usbmuxd_msgtype.PLIST, "tag": 1},
                "data": plistlib.dumps({"MessageType": "ReadBUID"}),
            })
            sock.send(message)
            response = usbmuxd_response.parse_stream(sock)

        finally:
            # If we sent a bad request, we should re-create the socket in the correct version this time
            sock.close()

        version = response.header.version
        if version == usbmuxd_version.BINARY:
            connection_class = BinaryMuxConnection
        elif version == usbmuxd_version.PLIST:
            connection_class = PlistMuxConnection
        else:
            # decide before opening the second socket so nothing is leaked on this path
            raise MuxVersionError(f"usbmuxd returned unsupported version: {version}")

        return connection_class(MuxConnection.create_usbmux_socket(usbmux_address=usbmux_address))

    def __init__(self, sock: SafeStreamSocket):
        self._sock = sock

        # after initiating the "Connect" packet, this same socket will be used to transfer data into the service
        # residing inside the target device. when this happens, we can no longer send/receive control commands to
        # usbmux on same socket
        self._connected = False

        # message sequence number. used when verifying the response matched the request
        self._tag = 1

        # devices collected by get_device_list()
        self.devices = []

    @abc.abstractmethod
    def _connect(self, device_id: int, port: int):
        """initiate a "Connect" request to target port"""
        pass

    @abc.abstractmethod
    def get_device_list(self, timeout: Optional[float] = None):
        """
        request an update to current device list
        """
        pass

    def connect(self, device: MuxDevice, port: int) -> socket.socket:
        """
        connect to a relay port on target machine and get a raw python socket object for the connection

        The port is passed through `socket.htons()` before being handed to
        `_connect`. After this call the connection is marked as connected and
        no further control messages can be issued on it.
        """
        self._connect(device.devid, socket.htons(port))
        self._connected = True
        return self._sock.sock

    def close(self):
        """close current socket"""
        self._sock.close()

    def _assert_not_connected(self):
        """verify active state is in state for control messages"""
        if self._connected:
            raise MuxException("Mux is connected, cannot issue control packets")

    def _raise_mux_exception(self, result: int, message: Optional[str] = None) -> None:
        """
        Raise the exception matching a usbmuxd result code; this method never returns.

        :param result: numeric result code received from usbmuxd
        :param message: text passed to the exception constructor
        :raises MuxException: or one of the specific classes mapped below
        """
        exceptions = {
            int(usbmuxd_result.BADCOMMAND): BadCommandError,
            int(usbmuxd_result.BADDEV): BadDevError,
            int(usbmuxd_result.CONNREFUSED): ConnectionFailedError,
            int(usbmuxd_result.NOSUCHSERVICE): ConnectionFailedError,
            int(usbmuxd_result.BADVERSION): MuxVersionError,
        }
        # unknown codes fall back to the generic MuxException
        exception = exceptions.get(result, MuxException)
        raise exception(message)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

class BinaryMuxConnection(MuxConnection):
    """old binary protocol"""

    # wait applied by get_device_list() when the caller passes no timeout;
    # same value list_devices() passes explicitly
    DEFAULT_DEVICE_LIST_TIMEOUT = 0.1

    def __init__(self, sock: SafeStreamSocket):
        super().__init__(sock)
        self._version = usbmuxd_version.BINARY

    def get_device_list(self, timeout: Optional[float] = None):
        """
        use timeout to wait for the device list to be fully populated

        Sends a listen request and then consumes device state updates until
        `timeout` seconds have elapsed.

        :param timeout: seconds to wait; None falls back to DEFAULT_DEVICE_LIST_TIMEOUT
        :raises MuxException: when the listener socket fails; the socket is closed first
        """
        self._assert_not_connected()
        if timeout is None:
            timeout = self.DEFAULT_DEVICE_LIST_TIMEOUT
        # monotonic clock: the deadline must not move when the wall clock is adjusted
        deadline = time.monotonic() + timeout
        self.listen()
        while True:
            # compute the remaining time once, so the value tested is the value passed
            # to settimeout() (a negative timeout raises ValueError)
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            self._sock.settimeout(remaining)
            try:
                self._receive_device_state_update()
            except (BlockingIOError, StreamError):
                # nothing (complete) to read right now; retry until the deadline
                continue
            except OSError as e:
                try:
                    self._sock.setblocking(True)
                    self.close()
                except OSError:
                    # best effort cleanup; the original error is reported below
                    pass
                raise MuxException("Exception in listener socket") from e

    def listen(self):
        """start listening for events of attached and detached devices"""
        self._send_receive(usbmuxd_msgtype.LISTEN)

    def _connect(self, device_id: int, port: int):
        """
        Send a CONNECT packet and verify usbmuxd answered with an OK result.

        :raises MuxException: on an unexpected reply type, or the exception
            mapped from the result code by `_raise_mux_exception`
        """
        self._send({
            "header": {"version": self._version, "message": usbmuxd_msgtype.CONNECT, "tag": self._tag},
            "data": {"device_id": device_id, "port": port},
        })
        response = self._receive()
        if response.header.message != usbmuxd_msgtype.RESULT:
            raise MuxException(f"unexpected message type received: {response}")

        if response.data.result != usbmuxd_result.OK:
            raise self._raise_mux_exception(
                int(response.data.result),
                f"failed to connect to device: {device_id} at port: {port}. reason: {response.data.result}",
            )

    def _send(self, data: dict) -> None:
        """Build and send a request packet, then advance the tag counter."""
        self._assert_not_connected()
        self._sock.send(usbmuxd_request.build(data))
        self._tag += 1

    def _receive(self, expected_tag: Optional[int] = None):
        """
        Parse one response packet from the socket.

        :param expected_tag: when not None, the tag the response header must carry
        :raises MuxException: on a tag mismatch
        """
        self._assert_not_connected()
        response = usbmuxd_response.parse_stream(self._sock)
        # explicit None test: a truthiness test would skip verification for tag 0
        if expected_tag is not None and response.header.tag != expected_tag:
            raise MuxException(f"Reply tag mismatch: expected {expected_tag}, got {response.header.tag}")
        return response

    def _send_receive(self, message_type: int):
        """Send a payload-less request of `message_type` and require an OK RESULT reply."""
        self._send({"header": {"version": self._version, "message": message_type, "tag": self._tag}, "data": b""})
        # _send() already advanced the counter, so the request carried _tag - 1
        response = self._receive(self._tag - 1)
        if response.header.message != usbmuxd_msgtype.RESULT:
            raise MuxException(f"unexpected message type received: {response}")

        result = response.data.result
        if result != usbmuxd_result.OK:
            raise self._raise_mux_exception(int(result), f"{message_type} failed: error {result}")

    def _add_device(self, device: MuxDevice):
        """Append `device` to the known device list."""
        self.devices.append(device)

    def _remove_device(self, device_id: int):
        """Drop every known device whose id equals `device_id`."""
        self.devices = [device for device in self.devices if device.devid != device_id]

    def _receive_device_state_update(self):
        """
        Read one packet and apply it to the device list.

        :raises MuxException: when the packet is neither ADD nor REMOVE
        """
        response = self._receive()
        if response.header.message == usbmuxd_msgtype.ADD:
            # old protocol only supported USB devices
            self._add_device(MuxDevice(response.data.device_id, response.data.serial_number, "USB"))
        elif response.header.message == usbmuxd_msgtype.REMOVE:
            self._remove_device(response.data.device_id)
        else:
            raise MuxException(f"Invalid packet type received: {response}")


class PlistMuxConnection(BinaryMuxConnection):
    """plist protocol: every request and reply is a plist carried inside a PLIST packet"""

    def __init__(self, sock: SafeStreamSocket):
        super().__init__(sock)
        self._version = usbmuxd_version.PLIST

    def listen(self) -> None:
        """start listening for events of attached and detached devices"""
        request = {"MessageType": "Listen"}
        self._send_receive(request)

    def get_pair_record(self, serial: str) -> dict:
        """
        Read and decode the pair record stored for `serial`.

        NOTE: the original comment here states that serials are saved inside
        usbmuxd without '-'; this method forwards `serial` unchanged.

        :raises NotPairedError: when the reply has no "PairRecordData"
        """
        request = {"MessageType": "ReadPairRecord", "PairRecordID": serial}
        self._send(request)
        reply = self._receive(self._tag - 1)
        record_data = reply.get("PairRecordData")
        if record_data is None:
            raise NotPairedError("device should be paired first")
        return plistlib.loads(record_data)

    def get_device_list(self, timeout: Optional[float] = None) -> None:
        """get device list synchronously without waiting the timeout"""
        self.devices = []
        self._send({"MessageType": "ListDevices"})
        reply = self._receive(self._tag - 1)
        entries = reply.get("DeviceList")
        if entries is None:
            raise MuxException(f"Got an invalid response from usbmux: {reply}")
        for entry in entries:
            kind = entry["MessageType"]
            if kind == "Attached":
                device_id = entry["DeviceID"]
                properties = entry["Properties"]
                attached = MuxDevice(
                    device_id,
                    properties["SerialNumber"],
                    properties["ConnectionType"],
                )
                super()._add_device(attached)
            elif kind == "Detached":
                super()._remove_device(entry["DeviceID"])
            else:
                raise MuxException(f"Invalid packet type received: {entry}")

    def get_buid(self) -> str:
        """get SystemBUID"""
        self._send({"MessageType": "ReadBUID"})
        reply = self._receive(self._tag - 1)
        return reply["BUID"]

    def save_pair_record(self, serial: str, device_id: int, record_data: bytes):
        """
        Store `record_data` as the pair record of `serial`.

        NOTE: the original comment here states that serials are saved inside
        usbmuxd without '-'; this method forwards `serial` unchanged.
        """
        request = {
            "MessageType": "SavePairRecord",
            "PairRecordID": serial,
            "PairRecordData": record_data,
            "DeviceID": device_id,
        }
        self._send_receive(request)

    def _connect(self, device_id: int, port: int):
        """initiate a "Connect" request to target port"""
        request = {"MessageType": "Connect", "DeviceID": device_id, "PortNumber": port}
        self._send_receive(request)

    def _send(self, data: dict):
        """Wrap `data` with the client identification keys and send it as a PLIST packet."""
        # keys from `data` override the identification defaults
        payload = {
            "ClientVersionString": "qt4i-usbmuxd",
            "ProgName": "pymobiledevice3",
            "kLibUSBMuxVersion": 3,
            **data,
        }
        header = {"version": self._version, "message": usbmuxd_msgtype.PLIST, "tag": self._tag}
        super()._send({"header": header, "data": plistlib.dumps(payload)})

    def _receive(self, expected_tag: Optional[int] = None) -> dict:
        """
        Receive one packet and return its decoded plist payload.

        :raises MuxException: when the packet is not of PLIST type
        """
        packet = super()._receive(expected_tag=expected_tag)
        if packet.header.message != usbmuxd_msgtype.PLIST:
            raise MuxException(f"Received non-plist type {packet}")
        return plistlib.loads(packet.data)

    def _send_receive(self, data: dict):
        """Send `data` and require a "Result" reply whose "Number" is 0."""
        self._send(data)
        # _send() already advanced the counter, so the request carried _tag - 1
        reply = self._receive(self._tag - 1)
        if reply["MessageType"] != "Result":
            raise MuxException(f"got an invalid message: {reply}")
        status = reply["Number"]
        if status != 0:
            raise self._raise_mux_exception(status, f"got an error message: {reply}")


def create_mux(usbmux_address: Optional[str] = None) -> MuxConnection:
    """
    Create a mux connection via MuxConnection.create().

    :param usbmux_address: optional usbmuxd address, forwarded unchanged
    """
    connection = MuxConnection.create(usbmux_address=usbmux_address)
    return connection


def list_devices(usbmux_address: Optional[str] = None) -> list[MuxDevice]:
    """
    Return the devices currently known to usbmuxd.

    A temporary mux connection is opened for the query and closed again
    before returning, whether or not the query succeeded.
    """
    # MuxConnection is a context manager whose exit closes the socket
    with create_mux(usbmux_address=usbmux_address) as mux:
        mux.get_device_list(0.1)
        return mux.devices


def select_device(
    udid: Optional[str] = None, connection_type: Optional[str] = None, usbmux_address: Optional[str] = None
) -> Optional[MuxDevice]:
    """
    select a UsbMux device according to given arguments.
    if more than one device could be selected, always prefer the usb one.

    Returns the first matching USB device; when no matching device is on USB,
    the last matching device; None when nothing matches.
    """

    def _is_candidate(device) -> bool:
        # a requested connection_type / udid must match; None means "any"
        if connection_type is not None and device.connection_type != connection_type:
            return False
        return udid is None or device.matches_udid(udid)

    fallback = None
    # filter() is lazy, so devices after the first USB match are never examined
    for candidate in filter(_is_candidate, list_devices(usbmux_address=usbmux_address)):
        if candidate.is_usb:
            # always prefer usb connection
            return candidate
        fallback = candidate

    return fallback


def select_devices_by_connection_type(connection_type: str, usbmux_address: Optional[str] = None) -> list[MuxDevice]:
    """
    select all UsbMux devices by connection type

    Devices are returned in the order list_devices() reports them.
    """
    return [
        device
        for device in list_devices(usbmux_address=usbmux_address)
        if device.connection_type == connection_type
    ]
